class Solution(object):
    """Locate the node immediately to the right of a given node on its level."""

    def findNearestRightNode(self, root, u):
        """Return the node directly to the right of ``u`` on the same level.

        Walks the tree rooted at ``root`` one level at a time (breadth-first,
        left child before right child), so each level list is ordered left to
        right.

        :param root: root of the binary tree; may be ``None`` (empty tree).
        :param u: the node to look for. It is matched by identity (``is``),
            so nodes that merely compare equal to ``u`` are not confused
            with it.
        :returns: the next node to the right of ``u`` on its level, or
            ``None`` if ``u`` is the rightmost node of its level, the tree
            is empty, or ``u`` does not occur in the tree.
        """
        level = [root] if root is not None else []
        while level:
            # One pass per level: identity match, then peek at the neighbour.
            for i, node in enumerate(level):
                if node is u:
                    return level[i + 1] if i + 1 < len(level) else None
            # Build the next level left-to-right, skipping missing children.
            level = [
                child
                for node in level
                for child in (node.left, node.right)
                if child is not None
            ]
        return None

